{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import json\n",
    "%matplotlib inline\n",
    "\n",
    "from deeppavlov.core.commands.utils import expand_path\n",
    "from deeppavlov.models.evolution.evolution_param_generator import ParamsEvolution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set here path to your config file, key main model and population size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "CONFIG_FILE = \"../../configs/evolution/evolve_intents_snips.json\"\n",
    "KEY_MAIN_MODEL = \"main\"\n",
    "POPULATION_SIZE = 2\n",
    "    \n",
    "with open(CONFIG_FILE, \"r\", encoding='utf8') as f:\n",
    "    basic_params = json.load(f)\n",
    "\n",
    "print(\"Considered basic config:\\n{}\".format(json.dumps(basic_params, indent=2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evolution = ParamsEvolution(population_size=POPULATION_SIZE,\n",
    "                            key_main_model=KEY_MAIN_MODEL,\n",
    "                            **basic_params)\n",
    "\n",
    "validate_best = evolution.get_value_from_config(\n",
    "    evolution.basic_config, list(evolution.find_model_path(\n",
    "        evolution.basic_config, \"validate_best\"))[0] + [\"validate_best\"])\n",
    "test_best = evolution.get_value_from_config(\n",
    "    evolution.basic_config, list(evolution.find_model_path(\n",
    "        evolution.basic_config, \"test_best\"))[0] + [\"test_best\"])\n",
    "\n",
    "TITLE = str(Path(evolution.get_value_from_config(\n",
    "    evolution.basic_config, evolution.main_model_path + [\"save_path\"])).stem)\n",
    "print(\"Title name for the considered evolution is `{}`.\".format(TITLE))\n",
    "\n",
    "data = pd.read_csv(str(expand_path(Path(evolution.get_value_from_config(\n",
    "    evolution.basic_config, evolution.main_model_path + [\"save_path\"])).joinpath(\n",
    "    \"result_table.tsv\"))), sep='\\t')\n",
    "print(\"Number of populations: {}.\".format(int(data.shape[0] / POPULATION_SIZE)))\n",
    "data.fillna(0., inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MEASURES = evolution.get_value_from_config(\n",
    "    evolution.basic_config, list(evolution.find_model_path(\n",
    "        evolution.basic_config, \"metrics\"))[0] + [\"metrics\"])\n",
    "\n",
    "for measure in MEASURES:\n",
    "    print(\"\\nMeasure: {}\".format(measure))\n",
    "    for data_type in [\"valid\", \"test\"]:\n",
    "        print(\"{}:\".format(data_type))\n",
    "        argmin = data[measure + \"_\" + data_type].argmin()\n",
    "        argmax = data[measure + \"_\" + data_type].argmax()\n",
    "        print(\"min for\\t{} model on\\t{} population\".format(argmin % POPULATION_SIZE,\n",
    "                                                             argmin // POPULATION_SIZE))\n",
    "        print(\"max for\\t{} model on\\t{} population\".format(argmax % POPULATION_SIZE,\n",
    "                                                             argmax // POPULATION_SIZE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## If you want to plot measures depending on population colored by evolved measure value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "path_to_pics = expand_path(Path(evolution.get_value_from_config(\n",
    "    evolution.basic_config, evolution.main_model_path + [\"save_path\"])).joinpath(\"pics\"))\n",
    "path_to_pics.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "if validate_best:\n",
    "    evolve_metric = MEASURES[0] + \"_valid\"\n",
    "elif test_best:\n",
    "    evolve_metric = MEASURES[0] + \"_test\"\n",
    "    \n",
    "cmap = plt.get_cmap('rainbow')\n",
    "colors = [cmap(i) for i in np.linspace(0, 1, data.shape[0])]\n",
    "color_ids = np.argsort(data.loc[:, evolve_metric].values)\n",
    "\n",
    "ylims = [(0., 1)] * len(MEASURES)\n",
    "\n",
    "for metric, ylim in zip(MEASURES, ylims):\n",
    "    plt.figure(figsize=(12,6))\n",
    "    if validate_best:\n",
    "        for i in range(data.shape[0]):\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        data.loc[:, metric + \"_valid\"].values[i], \n",
    "                        c=colors[np.where(color_ids == i)[0][0]], alpha=0.5, marker='o')\n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_valid\"].max() * np.ones(data.shape[0]//POPULATION_SIZE), \n",
    "             c=colors[-1])\n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_valid\"].min() * np.ones(data.shape[0]//POPULATION_SIZE), \n",
    "             c=colors[0])\n",
    "    if test_best:\n",
    "        for i in range(data.shape[0]):\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        data.loc[:, metric + \"_test\"].values[i], \n",
    "                        c=colors[np.where(color_ids == i)[0][0]], alpha=0.5, marker='+', s=200)\n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_test\"].max() * np.ones(data.shape[0]//POPULATION_SIZE), \"--\",\n",
    "             c=colors[-1])\n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_test\"].min() * np.ones(data.shape[0]//POPULATION_SIZE), \"--\",\n",
    "             c=colors[0])\n",
    "    \n",
    "\n",
    "    plt.ylabel(metric, fontsize=20)\n",
    "    plt.xlabel(\"population\", fontsize=20)\n",
    "    plt.title(TITLE, fontsize=20)\n",
    "    plt.ylim(ylim[0], ylim[1])\n",
    "    plt.xticks(fontsize=20)\n",
    "    plt.yticks(fontsize=20)\n",
    "    plt.savefig(path_to_pics.joinpath(metric + \".png\"))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## If you want to plot measures depending on population colored by `evolution_model_id`\n",
    "\n",
    "####  That means model of the same `id` are of the same color."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params_dictionaries = []\n",
    "models_ids = []\n",
    "\n",
    "for i in range(data.shape[0]):\n",
    "    data.loc[i, \"params\"] = data.loc[i, \"params\"].replace(\"False\", \"false\")\n",
    "    data.loc[i, \"params\"] = data.loc[i, \"params\"].replace(\"True\", \"true\")\n",
    "    json_acceptable_string = data.loc[i, \"params\"].replace(\"'\", \"\\\"\")\n",
    "    d = json.loads(json_acceptable_string)\n",
    "    params_dictionaries.append(d)\n",
    "    models_ids.append(d[\"evolution_model_id\"])\n",
    "\n",
    "models_ids = np.array(models_ids)\n",
    "models_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "cmap = plt.get_cmap('rainbow')\n",
    "colors = [cmap(i) for i in np.linspace(0, 1, len(np.unique(models_ids)))]\n",
    "\n",
    "ylims = [(0., 1)] * len(MEASURES)\n",
    "\n",
    "for metric, ylim in zip(MEASURES, ylims):\n",
    "    plt.figure(figsize=(12,6))\n",
    "    if validate_best:\n",
    "        for i in range(data.shape[0]):\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        data.loc[:, metric + \"_valid\"].values[i], \n",
    "#                         c=colors[models_ids[i]], alpha=0.5, marker='o')\n",
    "                        c=colors[np.where(models_ids[i] == np.unique(models_ids))[0][0]], alpha=0.5, marker='o')\n",
    "            \n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_valid\"].max() * np.ones(data.shape[0]//POPULATION_SIZE), \n",
    "             c=colors[-1])\n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_valid\"].min() * np.ones(data.shape[0]//POPULATION_SIZE), \n",
    "             c=colors[0])\n",
    "    if test_best:\n",
    "        for i in range(data.shape[0]):\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        data.loc[:, metric + \"_test\"].values[i], \n",
    "                        c=colors[np.where(models_ids[i] == np.unique(models_ids))[0][0]], alpha=0.5, marker='+', s=200)\n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_test\"].max() * np.ones(data.shape[0]//POPULATION_SIZE), \"--\",\n",
    "             c=colors[-1])\n",
    "        plt.plot(np.arange(data.shape[0]//POPULATION_SIZE), \n",
    "             data.loc[:, metric + \"_test\"].min() * np.ones(data.shape[0]//POPULATION_SIZE), \"--\",\n",
    "             c=colors[0])\n",
    "    \n",
    "\n",
    "    plt.ylabel(metric, fontsize=20)\n",
    "    plt.xlabel(\"population\", fontsize=20)\n",
    "    plt.title(TITLE, fontsize=20)\n",
    "    plt.ylim(ylim[0], ylim[1])\n",
    "    plt.xticks(fontsize=20)\n",
    "    plt.yticks(fontsize=20)\n",
    "    plt.savefig(path_to_pics.joinpath(metric + \"_colored_ids.png\"))\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "cmap = plt.get_cmap('rainbow')\n",
    "colors = [cmap(i) for i in np.linspace(0, 1, data.shape[0])]\n",
    "color_ids = np.argsort(data.loc[:, evolve_metric].values)\n",
    "\n",
    "for param_path in evolution.paths_to_evolving_params:\n",
    "    param_name = param_path[-1]\n",
    "    print(param_path, param_name)\n",
    "    \n",
    "    plt.figure(figsize=(12,12))\n",
    "    for i in range(data.shape[0]):\n",
    "        param_dict = evolution.get_value_from_config(evolution.basic_config, param_path)\n",
    "        if param_dict.get(\"evolve_range\") and param_dict.get(\"discrete\"):\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        evolution.get_value_from_config(params_dictionaries[i], param_path),\n",
    "#                         + (np.random.random() - 0.5) / 2,\n",
    "                        c=colors[np.where(color_ids == i)[0][0]], alpha=0.5)\n",
    "        elif param_dict.get(\"evolve_range\"):\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        evolution.get_value_from_config(params_dictionaries[i], param_path),\n",
    "                        c=colors[np.where(color_ids == i)[0][0]], alpha=0.5)\n",
    "        elif param_dict.get(\"evolve_choice\"):\n",
    "            values = np.array(param_dict.get(\"evolve_choice\"))\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        np.where(values == evolution.get_value_from_config(\n",
    "                            params_dictionaries[i], param_path))[0][0],\n",
    "                        c=colors[np.where(color_ids == i)[0][0]], alpha=0.5)\n",
    "            plt.yticks(np.arange(len(values)), values, fontsize=20)\n",
    "        elif param_dict.get(\"evolve_bool\"):\n",
    "            values = np.array([False, True])\n",
    "            plt.scatter(i // POPULATION_SIZE, \n",
    "                        np.where(values == evolution.get_value_from_config(\n",
    "                            params_dictionaries[i], param_path))[0][0],\n",
    "                        c=colors[np.where(color_ids == i)[0][0]], alpha=0.5)\n",
    "            plt.yticks(np.arange(len(values)), [\"False\", \"True\"], fontsize=20)\n",
    "\n",
    "    plt.ylabel(param_name, fontsize=20)\n",
    "    plt.xlabel(\"population\", fontsize=20)\n",
    "    plt.title(TITLE, fontsize=20)\n",
    "    plt.xticks(fontsize=20)\n",
    "    plt.yticks(fontsize=20)\n",
    "    plt.savefig(path_to_pics.joinpath(param_name + \".png\"))\n",
    "    plt.show()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python-deep36",
   "language": "python",
   "name": "deep36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
